# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RNN text classification."""

import numpy as np
import sonnet as snt

from task_set import registry
from task_set.tasks import base
from task_set.tasks import utils

import tensorflow.compat.v1 as tf


@registry.task_registry.register_sampler("rnn_text_classification_family")
def sample_rnn_text_classification_family_cfg(seed):
  """Draws a random configuration for an RNN text classification task.

  The result is a plain nested python structure that carries everything needed
  to later construct an instance of the problem.

  Args:
    seed: int Random seed used to create the generator driving every draw.

  Returns:
    A dictionary with the keys "embed_dim", "w_init", "dataset", "vocab_size",
    "core", "trainable_init" and "loss_compute".
  """
  random_state = np.random.RandomState(seed)

  # Every draw below is handed the same RandomState, so the order of these
  # statements determines the config produced for a given seed. Keep it fixed.
  embed_dim = utils.sample_log_int(random_state, 8, 128)
  initializer_cfg = utils.sample_initializer(random_state)
  dataset_cfg = utils.sample_text_dataset(random_state)
  # TODO(lmetz) trim this if using characters...
  vocab_size = utils.sample_log_int(random_state, 100, 10000)
  core_cfg = utils.sample_rnn_core(random_state)
  # `choice` yields a numpy scalar; store a builtin bool instead.
  trainable_init = bool(random_state.choice([True, False]))
  loss_compute = random_state.choice(["last", "avg", "max"])

  return {
      "embed_dim": embed_dim,
      "w_init": initializer_cfg,
      "dataset": dataset_cfg,
      "vocab_size": vocab_size,
      "core": core_cfg,
      "trainable_init": trainable_init,
      "loss_compute": loss_compute,
  }


@registry.task_registry.register_getter("rnn_text_classification_family")
def get_rnn_text_classification_family(cfg):
  """Get a task for the given cfg.

  The model embeds int tokens, runs the embeddings through an RNN core,
  reduces the per-step outputs along axis 1 according to
  `cfg["loss_compute"]`, and applies a linear layer whose logits are scored
  with softmax cross entropy against the one-hot labels.

  Args:
    cfg: config specifying the model generated by
      `sample_rnn_text_classification_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """

  w_init = utils.get_initializer(cfg["w_init"])
  # Only the weights of the final linear layer use the sampled initializer.
  init = {"w": w_init}

  def _build(batch):
    """Build the sonnet module.

    Args:
      batch: A dictionary with keys "label", "label_onehot", and "text" mapping
        to tensors. The "text" consists of int tokens. These tokens are
        truncated to the length of the vocab before performing an embedding
        lookup.

    Returns:
      Loss of the batch.

    Raises:
      ValueError: If `cfg["loss_compute"]` is not one of "last", "avg", "max".
    """
    vocab_size = cfg["vocab_size"]
    # When the dataset config reports a (truthy) max_token, never build an
    # embedding table larger than that.
    max_token = cfg["dataset"][1]["max_token"]
    if max_token:
      vocab_size = min(max_token, vocab_size)

    # Clip the max token to be vocab_size-1.
    # `tf.cast(..., tf.int32)` is the non-deprecated equivalent of
    # `tf.to_int32`.
    text = tf.cast(batch["text"], tf.int32)
    largest_index = tf.cast(tf.reshape(vocab_size - 1, [1, 1]), tf.int32)
    tokens = tf.minimum(text, largest_index)

    embed = snt.Embed(vocab_size=vocab_size, embed_dim=cfg["embed_dim"])
    embedded_tokens = embed(tokens)
    rnn = utils.get_rnn_core(cfg["core"])

    # Uses the static shape, so the leading (batch) dimension of "text" must
    # be known at graph construction time.
    batch_size = tokens.shape.as_list()[0]

    state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

    outputs, _ = tf.nn.dynamic_rnn(rnn, embedded_tokens, initial_state=state)

    # Collapse axis 1 of the RNN outputs into a single vector per example.
    loss_compute = cfg["loss_compute"]
    if loss_compute == "last":
      rnn_output = outputs[:, -1]  # grab the last output
    elif loss_compute == "avg":
      rnn_output = tf.reduce_mean(outputs, 1)  # average over length
    elif loss_compute == "max":
      rnn_output = tf.reduce_max(outputs, 1)  # max over length
    else:
      raise ValueError("Not supported loss_compute [%s]" % cfg["loss_compute"])

    # The number of classes is read from the one-hot label dimension.
    logits = snt.Linear(
        batch["label_onehot"].shape[1], initializers=init)(
            rnn_output)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  datasets = utils.get_text_dataset(cfg["dataset"])
  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
